
<!DOCTYPE html>
<html>
  <head>
    
<meta charset="utf-8" >

<title>Feature Admission and Eviction in XDL | dragon</title>
<meta name="description" content="Email (base64): MTY5MDMwMjk2M0BxcS5jb20=
">

<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/3.7.0/animate.min.css">

<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.2/css/all.css" integrity="sha384-fnmOCqbTlWIlj8LyTjo7mOUStjsKC4pOpQbqyi7RrhN7udi9RwhKkMHpvLbHG9Sr" crossorigin="anonymous">
<link rel="shortcut icon" href="https://dragonfive.gitee.io//favicon.ico?v=1740893463017">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.10.0/katex.min.css">
<link rel="stylesheet" href="https://dragonfive.gitee.io//styles/main.css">



<script src="https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script>
<script src="//cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.5.1/build/highlight.min.js"></script>



  </head>
  <body>
    <div id="app" class="main">
      <div class="site-header-container">
  <div class="site-header">
    <div class="left">
      <a href="https://dragonfive.gitee.io/">
        <img class="avatar" src="https://dragonfive.gitee.io//images/avatar.png?v=1740893463017" alt="" width="32px" height="32px">
      </a>
      <a href="https://dragonfive.gitee.io/">
        <h1 class="site-title">dragon</h1>
      </a>
    </div>
    <div class="right">
      <transition name="fade">
        <i class="icon" :class="{ 'icon-close-outline': menuVisible, 'icon-menu-outline': !menuVisible }" @click="menuVisible = !menuVisible"></i>
      </transition>
    </div>
  </div>
</div>

<transition name="fade">
  <div class="menu-container" style="display: none;" v-show="menuVisible">
    <div class="menu-list">
      
        
          <a href="/" class="menu purple-link">
            Home
          </a>
        
      
        
          <a href="/archives" class="menu purple-link">
            Archives
          </a>
        
      
        
          <a href="/tags" class="menu purple-link">
            Tags
          </a>
        
      
        
          <a href="/post/about" class="menu purple-link">
            About
          </a>
        
      
    </div>
  </div>
</transition>


      <div class="content-container">
        <div class="post-detail">
          
          <h2 class="post-title">Feature Admission and Eviction in XDL</h2>
          <div class="post-info post-detail-info">
            <span><i class="icon-calendar-outline"></i> 2019-11-08</span>
            
              <span>
                <i class="icon-pricetags-outline"></i>
                
                  <a href="https://dragonfive.gitee.io/tag/xdl/">
                    XDL
                    
                      ,
                    
                  </a>
                
                  <a href="https://dragonfive.gitee.io/tag/WzibKNMac/">
                    Deep Learning
                    
                  </a>
                
              </span>
            
          </div>
          <div class="post-content" v-pre>
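            <p>This post describes the underlying data structures behind embedding and statis in XDL.</p>
<p>Both embedding and statis are in fact variables stored on the parameter server (PS) side, and each consists of a data part and a slot part.</p>
<p>For statis, the main content lives in the __decay_rate slot. Its data part is a zero-initialized tensor of shape [featuredim, 1] that is never actually used.</p>
<p>On the PS server, a variable is structured as follows:</p>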
<pre><code class="language-c++">class Variable {
 public:
  // Declared first so that slots_ below can refer to it.
  struct Slot {
    std::unique_ptr&lt;Tensor&gt; tensor;
    SlotJoiner joiner;
  };

 private:
  std::unique_ptr&lt;Tensor&gt; data_;   // the parameters (emd_dim)
  std::unique_ptr&lt;Data&gt; slicer_;   // hashmap: maps hash key to id
  std::unordered_map&lt;std::string, Slot&gt; slots_;  // slot name to its slot
};
</code></pre>
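<p>The hashmap manages the feature admission and eviction policies. As the signature of its Get function shows, it builds the mapping from hash key to id; ids holds the position of each input key in the data tensor.</p>
<p>This is where add_probability takes effect to perform feature admission, while reused_ids is used to evict some ids.</p>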
<h1 id="特征准入">Feature admission</h1>
<pre><code class="language-c++">int64_t Get(const int64_t* keys,
            size_t size, bool not_insert,
            float add_probability,                       // probability of admitting a feature key
            std::vector&lt;size_t&gt;* ids,                    // position of each input key in the data tensor
            tbb::concurrent_vector&lt;size_t&gt;* reused_ids,  // ids being evicted
            size_t* filtered_keys,
            size_t block_size = 500, size_t timeout = 1800)
</code></pre>
<p>Here add_probability is the probability with which a feature key is admitted.</p>
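<p>To make the idea concrete, here is a minimal sketch of probabilistic admission. It assumes that a key seen for the first time receives a fresh id only when a Bernoulli draw with probability add_probability succeeds. ToyHashMap, kNotAdmitted and the seeding are hypothetical names chosen for illustration; this is not XDL's HashMapImpl, and it ignores concurrency, block_size and timeout.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <unordered_map>
#include <vector>

// Marker for a key that was not admitted (hypothetical, for illustration).
constexpr size_t kNotAdmitted = static_cast<size_t>(-1);

// Hypothetical toy map from feature key to row id.
class ToyHashMap {
 public:
  explicit ToyHashMap(unsigned seed) : rng_(seed) {}

  // Looks up every key. A key that is already known keeps its id. An unseen
  // key is admitted with probability add_probability (expected in [0, 1]);
  // otherwise it is reported as kNotAdmitted.
  // Returns the number of keys that were filtered out.
  size_t Get(const std::vector<int64_t>& keys, float add_probability,
             std::vector<size_t>* ids) {
    std::bernoulli_distribution admit(add_probability);
    size_t filtered = 0;
    ids->clear();
    for (int64_t key : keys) {
      auto it = key_to_id_.find(key);
      if (it != key_to_id_.end()) {
        ids->push_back(it->second);     // already admitted earlier
      } else if (admit(rng_)) {
        size_t id = key_to_id_.size();  // next unused row
        key_to_id_.emplace(key, id);
        ids->push_back(id);
      } else {
        ids->push_back(kNotAdmitted);   // rejected this time
        ++filtered;
      }
    }
    return filtered;
  }

  size_t size() const { return key_to_id_.size(); }

 private:
  std::mt19937 rng_;
  std::unordered_map<int64_t, size_t> key_to_id_;
};
```

<p>With add_probability set to 0 no new key ever enters the map, and with 1 every new key does. Values in between mean that a rare key is likely to be dropped while a frequent key gets in after a few attempts.</p>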
<h1 id="特征退出">Feature eviction</h1>
<pre><code class="language-c++">if (max_id &gt; 0) {
  PS_CHECK_STATUS(variable-&gt;ReShapeId(max_id));
}
if (reused_ids.size() != 0) {
  std::vector&lt;size_t&gt; raw_reused_ids;
  for (auto iter : reused_ids) {
    raw_reused_ids.push_back(iter);
  }
  variable-&gt;ClearIds(raw_reused_ids);
}
</code></pre>
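<p>The snippet above grows the variable and then clears the rows whose ids are being recycled. A rough sketch of what those two calls might do is given below. ToyVariable is a hypothetical stand-in, and both behaviours are assumptions: that ReShapeId only ever grows the table, and that ClearIds resets a row to zero so that the next key receiving that id does not inherit stale values.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a PS variable: a dense table of dim-wide rows.
class ToyVariable {
 public:
  ToyVariable(size_t rows, size_t dim) : dim_(dim), data_(rows * dim, 0.0f) {}

  // Grows the table so that ids in [0, max_id) are valid rows. Never shrinks.
  void ReShapeId(size_t max_id) {
    if (max_id * dim_ > data_.size()) {
      data_.resize(max_id * dim_, 0.0f);
    }
  }

  // Resets the rows of recycled ids to zero.
  void ClearIds(const std::vector<size_t>& ids) {
    for (size_t id : ids) {
      for (size_t j = 0; j < dim_; ++j) {
        data_[id * dim_ + j] = 0.0f;
      }
    }
  }

  float& at(size_t id, size_t j) { return data_[id * dim_ + j]; }
  size_t rows() const { return data_.size() / dim_; }

 private:
  size_t dim_;
  std::vector<float> data_;
};
```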
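<p>A single variable is split by a balancing policy so that different parts of it are stored on different PS nodes.<br>
When specific keys of a variable are requested, the client therefore has to work out which keys to request from which PS. This is done by the partition code on the client side.</p>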
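<p>As an illustration of this client-side step, the sketch below routes keys to servers under one simple assumption: every key is hashed into a fixed range of buckets, and each PS owns a contiguous slice of that range. kRange, Bucket and SplitKeys are hypothetical names, the modulo hash is a placeholder, and the real partitioner in XDL may differ.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr size_t kRange = 1024;  // hypothetical number of hash buckets

// Maps a key into [0, kRange). A real system would use a proper hash.
size_t Bucket(int64_t key) {
  return static_cast<size_t>(static_cast<uint64_t>(key) % kRange);
}

// part_sizes[i] is the number of consecutive buckets owned by server i;
// the sizes are expected to sum to kRange.
// Returns, for each server, the keys that should be requested from it.
std::vector<std::vector<int64_t>> SplitKeys(
    const std::vector<int64_t>& keys, const std::vector<size_t>& part_sizes) {
  std::vector<std::vector<int64_t>> per_server(part_sizes.size());
  for (int64_t key : keys) {
    size_t bucket = Bucket(key);
    size_t end = 0;  // exclusive upper bucket bound of the current server
    for (size_t i = 0; i < part_sizes.size(); ++i) {
      end += part_sizes[i];
      if (bucket < end) {
        per_server[i].push_back(key);
        break;
      }
    }
  }
  return per_server;
}
```

<p>Giving a server a larger slice of the bucket range sends it a proportionally larger share of the keys, which is one way a balancing policy could be expressed.</p>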
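<p>Take the initialization op PsConstantInitializerOp as an example: it calls HashInitializer on the client side, and the client uses partitioner::HashShape to compute the shape of the variable on each PS.</p>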
<pre><code class="language-c++">
void Client::HashInitializer(const std::string&amp; variable_name,
                             Initializer* init,
                             const Client::Callback&amp; cb) {
  VariableInfo info;
  CHECK_ASYNC(GetVariableInfo(variable_name, &amp;info));
  std::string extra_info;
  for (auto&amp; arg : info.args) {
    extra_info += arg.first + &quot;=&quot; + arg.second + &quot;&amp;&quot;;
  }
  if (!extra_info.empty()) { extra_info.pop_back(); }
  std::vector&lt;Data*&gt; inputs = Args(0, 0, extra_info, std::unique_ptr&lt;Initializer&gt;(init));
  std::vector&lt;std::unique_ptr&lt;Data&gt;&gt;* outputs =
    new std::vector&lt;std::unique_ptr&lt;Data&gt;&gt;;
  std::vector&lt;Partitioner*&gt; splitter = {
    new partitioner::HashDataType,
    new partitioner::HashShape,
    new partitioner::Broadcast,
    new partitioner::Broadcast
  };
  std::vector&lt;Partitioner*&gt; combiner = {};
  UdfData udf(&quot;HashVariableInitializer&quot;, UdfData(0), UdfData(1), UdfData(2), UdfData(3));
  ...
}
 
 
Status HashShape::Split(PartitionerContext* ctx, Data* src, std::vector&lt;Data*&gt;* dst) {
  VariableInfo* info = ctx-&gt;GetVariableInfo();
  std::vector&lt;size_t&gt; dims(info-&gt;shape.begin(), info-&gt;shape.end());
  dst-&gt;clear();
  for (size_t i = 0; i &lt; info-&gt;parts.size(); i++) {
    size_t k = info-&gt;parts[i].size * info-&gt;shape[0] / Hasher::kTargetRange;
    dims[0] = k + 10 * sqrt(k) + 10;
    Data* d = new WrapperData&lt;TensorShape&gt;(dims);
    ctx-&gt;AddDeleter(d);
    dst-&gt;push_back(d);
  }
  return Status::Ok();
}

</code></pre>
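<p>The row count computed in HashShape::Split is worth a worked example. The helper below repeats the same arithmetic as a standalone function: k is the share of rows proportional to the part's size, and 10 * sqrt(k) + 10 is added on top. Reading that extra term as headroom for uneven hashing is an interpretation rather than something the source states. RowsForPart and the numbers in the comments are hypothetical.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Rows reserved on one server, following the arithmetic in HashShape::Split.
//   part_size    : size of this server's part of the hash range
//   total_rows   : first dimension of the whole variable (info->shape[0])
//   target_range : total size of the hash range (Hasher::kTargetRange)
size_t RowsForPart(size_t part_size, size_t total_rows, size_t target_range) {
  size_t k = part_size * total_rows / target_range;  // proportional share
  // The fractional part is truncated, as in the original assignment.
  return static_cast<size_t>(k + 10 * std::sqrt(static_cast<double>(k)) + 10);
}

// Example with made-up numbers: a part covering half of a 1024-wide range,
// for a variable with 20000 rows:
//   k    = 512 * 20000 / 1024  = 10000
//   rows = 10000 + 10 * 100 + 10 = 11010
```

<p>Since sqrt(k) grows more slowly than k, the relative overhead shrinks as parts get larger, while the constant 10 keeps very small parts from ending up with no spare rows at all.</p>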
<p>The client tells the PS which operation to perform through a UDF. In this initialization, for instance, it tells the PS server to run the SimpleRun method of HashVariableInitializer.</p>
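<p>Dispatching by name can be pictured with a small registry on the server side. The sketch below is only a model of the idea: ToyUdfRegistry, its string-in and string-out signature, and the lookup by std::map are all assumptions made for illustration and say nothing about how XDL actually registers or runs its UDFs.</p>

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical server-side registry: the client sends a UDF name plus an
// argument, and the server looks the name up and runs the matching function.
class ToyUdfRegistry {
 public:
  using Udf = std::function<std::string(const std::string&)>;

  void Register(const std::string& name, Udf udf) {
    udfs_[name] = std::move(udf);
  }

  // Runs the UDF registered under name and stores its output in *result.
  // Returns false if no UDF with this name is registered.
  bool Run(const std::string& name, const std::string& arg,
           std::string* result) const {
    auto it = udfs_.find(name);
    if (it == udfs_.end()) {
      return false;
    }
    *result = it->second(arg);
    return true;
  }

 private:
  std::map<std::string, Udf> udfs_;
};
```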
<pre><code class="language-c++">class HashVariableInitializer : public SimpleUdf&lt;DataType, TensorShape, std::string, std::unique_ptr&lt;Initializer&gt;&gt; {
 public:
  virtual Status SimpleRun(
      UdfContext* ctx,
      const DataType&amp; dt,
      const TensorShape&amp; shape,
      const std::string&amp; extra_info,
      const std::unique_ptr&lt;Initializer&gt;&amp; initializer) const {
    ...
    std::vector&lt;std::string&gt; slots;
    ...
    ps::Status status = GetStorageManager(ctx)-&gt;Get(var_name, &amp;var);
    ...
    if (!status.IsOk()) {
      return ctx-&gt;GetStorageManager()-&gt;Set(var_name, [&amp;]{
            HashMap* hashmap = new HashMapImpl&lt;int64_t&gt;(init_shape[0]);
            Variable* var = new Variable(new Tensor(dt, init_shape, initializer-&gt;Clone(), Tensor::TType::kSegment, true), new WrapperData&lt;std::unique_ptr&lt;HashMap&gt; &gt;(hashmap), var_name);
            Status st = InitSlots(slots, var);
            ...
            return var;
          });
    } else {
      ...
    }
  }

  Status InitSlots(const std::vector&lt;std::string&gt;&amp; slots, Variable* var) const {
    for (auto&amp; slot : slots) {
      ...
      Tensor* t = var-&gt;GetVariableLikeSlot(tuple[0], dtype, TensorShape(inner_dims), []{ return new initializer::ConstantInitializer(0); });
      ...
    }
    return Status::Ok();
  }
};
</code></pre>

          </div>
        </div>

        
          <div class="next-post">
            <a class="purple-link" href="https://dragonfive.gitee.io/post/tensorflow-1x-de-fen-bu-shi-xun-lian/">
              <h3 class="post-title">
                Next: Distributed training in TensorFlow 1.x
              </h3>
            </a>
          </div>
          
      </div>

      

      <div class="site-footer">
  <div class="slogan">Email (base64): MTY5MDMwMjk2M0BxcS5jb20=
</div>
  <div class="social-container">
    
      
        <a href="https://github.com/DragonFive" target="_blank">
          <i class="fab fa-github"></i>
        </a>
      
    
      
    
      
    
      
    
      
    
  </div>
  Powered by <a href="https://github.com/getgridea/gridea" target="_blank">Gridea</a> | <a class="rss" href="https://dragonfive.gitee.io//atom.xml" target="_blank">RSS</a>
</div>


    </div>
    <script type="application/javascript">

hljs.initHighlightingOnLoad()

var app = new Vue({
  el: '#app',
  data: {
    menuVisible: false,
  },
})

</script>




  </body>
</html>
